{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c290621d-118d-4bc7-a9d6-31ae560b914d",
   "metadata": {},
   "source": [
    "# ReLU-KAN Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2524b46-1a53-4805-a0a8-4e29fac5c7f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import optax\n",
    "from flax import linen as nn\n",
    "\n",
    "import sys\n",
    "import os\n",
    "\n",
    "import time\n",
    "import scipy\n",
    "\n",
    "# Add /src to path\n",
    "path_to_src = os.path.abspath(os.path.join(os.getcwd(), '../../../src'))\n",
    "if path_to_src not in sys.path:\n",
    "    sys.path.append(path_to_src)\n",
    "\n",
    "from ReLUKAN import ReLUKAN\n",
    "from PIKAN import *\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6063fffe-caab-4999-9b55-a468b94cc7ee",
   "metadata": {},
   "source": [
    "## Collocation Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0067fb99-31f2-4a07-a8f8-f6d7640e557f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interior collocation points for the PDE residual, sampled on the square [-1, 1] x [-1, 1]\n",
    "N = 2**12\n",
    "collocs = jnp.array(sobol_sample(np.array([-1,-1]), np.array([1,1]), N)) # (4096, 2)\n",
    "\n",
    "# Boundary collocation points: each edge is a box whose lower and upper bounds\n",
    "# coincide in one coordinate. The edges are sampled in the order BC1 -> BC4.\n",
    "N = 2**6\n",
    "edge_bounds = [\n",
    "    ([-1,-1], [-1,1]),  # BC1: first coordinate fixed at -1\n",
    "    ([-1,-1], [1,-1]),  # BC2: second coordinate fixed at -1\n",
    "    ([1,-1], [1,1]),    # BC3: first coordinate fixed at 1\n",
    "    ([-1,1], [1,1]),    # BC4: second coordinate fixed at 1\n",
    "]\n",
    "bc_collocs = [jnp.array(sobol_sample(np.array(lower), np.array(upper), N))\n",
    "              for lower, upper in edge_bounds] # each (64, 2)\n",
    "\n",
    "# Boundary targets: a single column of zeros per edge\n",
    "bc_data = [jnp.zeros((edge_pts.shape[0], 1)) for edge_pts in bc_collocs] # each (64, 1)\n",
    "\n",
    "# Individual names are kept because later cells refer to them directly\n",
    "BC1_colloc, BC2_colloc, BC3_colloc, BC4_colloc = bc_collocs\n",
    "BC1_data, BC2_data, BC3_data, BC4_data = bc_data"
   ]
  },
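  {
   "cell_type": "markdown",
   "id": "c78af137-e282-49c7-a231-5b7a163ea015",
   "metadata": {},
   "source": [
    "## Loss Function\n",
    "\n",
    "The PDE residual is that of the 2D Helmholtz equation $u_{xx} + u_{yy} + k^2 u = q(x, y)$ with $k = 1$. The source term $q$ is built from $\\sin(a_1 \\pi x)\\sin(a_2 \\pi y)$ with $a_1 = 1$, $a_2 = 4$, which is also the analytical solution used for the error computation later in the notebook."
   ]
  },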
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b4ae7d8-b3f0-44d5-8ecc-5a398528ae94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pde_loss(params, collocs, state):\n",
    "    \"\"\"\n",
    "    Pointwise residual of the Helmholtz equation u_xx + u_yy + k^2 u = q.\n",
    "\n",
    "    The source term q is assembled from sin(a1*pi*x)*sin(a2*pi*y), so that this\n",
    "    product satisfies the equation exactly.\n",
    "\n",
    "    Args:\n",
    "        params: passed to the model under the 'params' key.\n",
    "        collocs: collocation points; column 0 is read as x and column 1 as y.\n",
    "        state: passed to the model under the 'state' key.\n",
    "\n",
    "    Returns:\n",
    "        The residual evaluated at the collocation points. No squaring or\n",
    "        averaging is applied here.\n",
    "\n",
    "    Note:\n",
    "        Uses the notebook-level `model`, and `gradf`, which is not defined in\n",
    "        this notebook (presumably provided by the PIKAN star import - verify).\n",
    "    \"\"\"\n",
    "    # Eq. parameters\n",
    "    k = jnp.array(1.0, dtype=float)\n",
    "    a1 = jnp.array(1.0, dtype=float)\n",
    "    a2 = jnp.array(4.0, dtype=float)\n",
    "\n",
    "    # Define the model function\n",
    "    variables = {'params' : params, 'state' : state}\n",
    "\n",
    "    def u(vec_x):\n",
    "        # model.apply returns a pair; only the first element is used here\n",
    "        y, _ = model.apply(variables, vec_x)\n",
    "        return y\n",
    "\n",
    "    # Physics Loss Terms\n",
    "    u_xx = gradf(u, 0, 2)  # 2nd order derivative of x\n",
    "    u_yy = gradf(u, 1, 2)  # 2nd order derivative of y\n",
    "\n",
    "    sines = jnp.sin(a1*jnp.pi*collocs[:,[0]])*jnp.sin(a2*jnp.pi*collocs[:,[1]])\n",
    "    # The reaction part of the source must carry k**2 (not k) so that it matches\n",
    "    # the (k**2)*u term of the residual for every k, not only for k == 1.\n",
    "    source = -((a1*jnp.pi)**2)*sines - ((a2*jnp.pi)**2)*sines + (k**2)*sines\n",
    "\n",
    "    # Residual\n",
    "    pde_res = u_xx(collocs) + u_yy(collocs) + (k**2)*u(collocs) - source\n",
    "\n",
    "    return pde_res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "537797e7-95c9-468a-b8ea-0be8ef80b88f",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02f8d4aa-29c7-4ee2-a3fa-74ea7fe6d339",
   "metadata": {},
   "source": [
    "### Training Static Case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6571fe95-8644-4cc5-86a5-07420116389b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model and initialise its variables from a dummy (1, 2) input\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = ReLUKAN(layer_dims=layer_dims, p=2, k=2, const_R=1.0, const_res=0.0, add_bias=True, grid_e=1.0)\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Learning rates for the scheduler: a single scale entry at epoch 0\n",
    "lr_vals = {'init_lr' : 0.001, 'scales' : {0 : 1.0}}\n",
    "\n",
    "# Epochs for grid adaptation (none)\n",
    "grid_adapt = []\n",
    "\n",
    "# Epochs for grid extension, along with grid sizes (only the initial entry)\n",
    "grid_extend = {0 : 3}\n",
    "\n",
    "# Global loss weights, one scalar per loss term\n",
    "# (presumably ordered like loc_w below: PDE term first, then BC1..BC4 - verify)\n",
    "glob_w = [jnp.array(weight, dtype=float) for weight in (0.01, 1.0, 1.0, 1.0, 1.0)]\n",
    "\n",
    "# RBA weights: one column of ones per point set\n",
    "loc_w = [jnp.ones((pts.shape[0],1))\n",
    "         for pts in (collocs, BC1_colloc, BC2_colloc, BC3_colloc, BC4_colloc)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e73963f-6eb3-4475-940d-3c10532b19fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100_000\n",
    "\n",
    "# Keyword options forwarded to train_PIKAN\n",
    "train_kwargs = dict(\n",
    "    glob_w=glob_w, lr_vals=lr_vals, adapt_state=True, loc_w=loc_w, nesterov=True,\n",
    "    num_epochs=num_epochs, grid_extend=grid_extend, grid_adapt=grid_adapt,\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")\n",
    "\n",
    "# NOTE: model and variables are rebound to the returned values, so running this\n",
    "# cell again starts from the already-trained state rather than from the init cell.\n",
    "model, variables, train_losses = train_PIKAN(model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "                                             **train_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6904204-b21c-44f9-b9b6-7833376cf783",
   "metadata": {},
   "outputs": [],
   "source": [
    "def helm_exact(x,y):\n",
    "    \"\"\"Analytical solution sin(a1*pi*x)*sin(a2*pi*y) with a1 = 1 and a2 = 4.\"\"\"\n",
    "    a1, a2 = 1.0, 4.0\n",
    "    return np.sin(a1*np.pi*x)*np.sin(a2*np.pi*y)\n",
    "\n",
    "# Regular evaluation grid on [-1, 1] x [-1, 1]; 'ij' indexing gives arrays of shape (N_x, N_y)\n",
    "N_x, N_y = 100, 256\n",
    "x = np.linspace(-1.0, 1.0, N_x)\n",
    "y = np.linspace(-1.0, 1.0, N_y)\n",
    "X, Y = np.meshgrid(x, y, indexing='ij')\n",
    "\n",
    "# One (x, y) pair per row, in the same row-major order used by the reshape below\n",
    "coords = np.column_stack([X.ravel(), Y.ravel()])\n",
    "\n",
    "ref = helm_exact(X, Y)\n",
    "\n",
    "# Model prediction on the grid; the second element returned by apply is not needed\n",
    "output, _ = model.apply(variables, jnp.array(coords))\n",
    "static = np.array(output).reshape(N_x, N_y)\n",
    "\n",
    "# Relative L^2 error against the analytical solution\n",
    "l2err = jnp.linalg.norm(static-ref)/jnp.linalg.norm(ref)\n",
    "print(f\"L^2 Error = {l2err*100:.4f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fb710f8-f319-4d37-94cf-d0633dfa2c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss history on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(train_losses), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f2926d7-701d-468a-8c83-30a0397c7613",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute error between the prediction and the analytical solution on the evaluation grid\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "err_mesh = ax.pcolormesh(X, Y, np.abs(static-ref), shading='auto', cmap='Spectral_r')\n",
    "fig.colorbar(err_mesh, ax=ax)\n",
    "\n",
    "ax.set_title('Absolute Error for Helmholtz Equation')\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('y')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f89288e-fd52-42af-bd5f-978c685a2da2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): epochs is computed here but is not written to the archive below\n",
    "epochs = np.arange(num_epochs)\n",
    "\n",
    "# Store the grid axes, the prediction and the analytical solution for plotting elsewhere\n",
    "saved_arrays = {'x' : x, 'y' : y, 'res' : static, 'ref' : ref}\n",
    "np.savez('../Plots/data/relu1.npz', **saved_arrays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d2d195c-43fe-4d2a-85f2-a4951671c842",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a3ce495b-7a7c-4da9-ad1a-fefebbad5a8f",
   "metadata": {},
   "source": [
    "### Training Non-fully Adaptive Case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f34e8921-a811-41fe-adca-faf605b69bdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the model and initialise its variables from a dummy (1, 2) input\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = ReLUKAN(layer_dims=layer_dims, p=2, k=2, const_R=1.0, const_res=0.0, add_bias=True, grid_e=1.0)\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Learning rates for the scheduler: initial value and scale entries keyed by epoch\n",
    "lr_vals = {\n",
    "    'init_lr' : 0.001,\n",
    "    'scales' : {0 : 1.0, 8_000 : 0.5, 15_000 : 0.5, 30_000 : 0.4, 50_000 : 0.7, 70_000 : 0.7},\n",
    "}\n",
    "\n",
    "# Epochs for grid adaptation (none)\n",
    "grid_adapt = []\n",
    "\n",
    "# Epochs for grid extension, along with grid sizes\n",
    "grid_extend = {0 : 3, 20_000 : 6, 35_000 : 12}\n",
    "\n",
    "# Global loss weights, one scalar per loss term\n",
    "# (presumably ordered like loc_w below: PDE term first, then BC1..BC4 - verify)\n",
    "glob_w = [jnp.array(weight, dtype=float) for weight in (0.01, 1.0, 1.0, 1.0, 1.0)]\n",
    "\n",
    "# RBA weights: one column of ones per point set\n",
    "loc_w = [jnp.ones((pts.shape[0],1))\n",
    "         for pts in (collocs, BC1_colloc, BC2_colloc, BC3_colloc, BC4_colloc)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2335634-6afd-4c8a-8989-6111518b2948",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100_000\n",
    "\n",
    "# Keyword options forwarded to train_PIKAN\n",
    "train_kwargs = dict(\n",
    "    glob_w=glob_w, lr_vals=lr_vals, adapt_state=True, loc_w=loc_w, nesterov=True,\n",
    "    num_epochs=num_epochs, grid_extend=grid_extend, grid_adapt=grid_adapt,\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")\n",
    "\n",
    "# NOTE: model and variables are rebound to the returned values, so running this\n",
    "# cell again starts from the already-trained state rather than from the init cell.\n",
    "model, variables, train_losses2 = train_PIKAN(model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "                                              **train_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "174a6240-227d-423f-93c0-235156bf47ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "def helm_exact(x,y):\n",
    "    \"\"\"Analytical solution sin(a1*pi*x)*sin(a2*pi*y) with a1 = 1 and a2 = 4.\"\"\"\n",
    "    a1, a2 = 1.0, 4.0\n",
    "    return np.sin(a1*np.pi*x)*np.sin(a2*np.pi*y)\n",
    "\n",
    "# Regular evaluation grid on [-1, 1] x [-1, 1]; 'ij' indexing gives arrays of shape (N_x, N_y)\n",
    "N_x, N_y = 100, 256\n",
    "x = np.linspace(-1.0, 1.0, N_x)\n",
    "y = np.linspace(-1.0, 1.0, N_y)\n",
    "X, Y = np.meshgrid(x, y, indexing='ij')\n",
    "\n",
    "# One (x, y) pair per row, in the same row-major order used by the reshape below\n",
    "coords = np.column_stack([X.ravel(), Y.ravel()])\n",
    "\n",
    "ref = helm_exact(X, Y)\n",
    "\n",
    "# Model prediction on the grid; the second element returned by apply is not needed\n",
    "output, _ = model.apply(variables, jnp.array(coords))\n",
    "nonfullya = np.array(output).reshape(N_x, N_y)\n",
    "\n",
    "# Relative L^2 error against the analytical solution\n",
    "l2err = jnp.linalg.norm(nonfullya-ref)/jnp.linalg.norm(ref)\n",
    "print(f\"L^2 Error = {l2err*100:.4f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a71d1b7a-b9ac-48ae-b558-b058465bcc3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss history on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(train_losses2), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acc738a8-e2e9-4ab1-92cc-dc156d22feb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute error between the prediction and the analytical solution on the evaluation grid\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "err_mesh = ax.pcolormesh(X, Y, np.abs(nonfullya-ref), shading='auto', cmap='Spectral_r')\n",
    "fig.colorbar(err_mesh, ax=ax)\n",
    "\n",
    "ax.set_title('Absolute Error for Helmholtz Equation')\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('y')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69622a13-49a8-4e82-843d-3ba3a5b84ddd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): epochs and grids are computed here but are not written to the archive below\n",
    "epochs = np.arange(num_epochs)\n",
    "grids = np.array(list(grid_extend.keys()))\n",
    "\n",
    "# Store the grid axes, the prediction and the analytical solution for plotting elsewhere\n",
    "saved_arrays = {'x' : x, 'y' : y, 'res' : nonfullya, 'ref' : ref}\n",
    "np.savez('../Plots/data/relu2.npz', **saved_arrays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17ad79c1-a37e-4094-802a-c4f70b3cc067",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b77c622d-f4b4-4bc5-962c-61c8664e41ab",
   "metadata": {},
   "source": [
    "### Training Fully Adaptive Case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "110ffb23-759b-4d8b-abdf-4b3adc2f8bb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "layer_dims = [2, 8, 8, 1]\n",
    "model = ReLUKAN(layer_dims=layer_dims, p=2, k=2, const_R=False, const_res=False, add_bias=True, grid_e=0.05)\n",
    "variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 2]))\n",
    "\n",
    "# Define learning rates for scheduler\n",
    "lr_vals = dict()\n",
    "lr_vals['init_lr'] = 0.001\n",
    "lr_vals['scales'] = {0 : 1.0, 20_000 : 0.6, 35_000 : 0.8, 50_000 : 0.7, 70_000 : 0.7}\n",
    "\n",
    "# Define epochs for grid adaptation\n",
    "# This cell used to build the schedule below and then immediately overwrite it\n",
    "# with an empty list, so no grid adaptation epochs reached training. The flag\n",
    "# keeps that outcome while making the choice visible; set it to True to use\n",
    "# the schedule (every adapt_every epochs, up to and including adapt_stop).\n",
    "adapt_every = 500\n",
    "adapt_stop = 40000\n",
    "use_grid_adapt = False\n",
    "grid_adapt = list(range(adapt_every, adapt_stop + 1, adapt_every)) if use_grid_adapt else []\n",
    "\n",
    "# Define epochs for grid extension, along with grid sizes\n",
    "grid_extend = {0 : 3, 20_000 : 6, 35_000 : 12}\n",
    "\n",
    "# Define global loss weights\n",
    "# (presumably ordered like loc_w below: PDE term first, then BC1..BC4 - verify)\n",
    "glob_w = [jnp.array(0.01, dtype=float), jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float),\n",
    "          jnp.array(1.0, dtype=float), jnp.array(1.0, dtype=float)]\n",
    "\n",
    "# Initialize RBA weights: one column of ones per point set\n",
    "loc_w = [jnp.ones((collocs.shape[0],1)), jnp.ones((BC1_colloc.shape[0],1)),\n",
    "         jnp.ones((BC2_colloc.shape[0],1)), jnp.ones((BC3_colloc.shape[0],1)),\n",
    "         jnp.ones((BC4_colloc.shape[0],1))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09578531-4731-4925-b16d-32c7cdcdac93",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100_000\n",
    "\n",
    "# Keyword options forwarded to train_PIKAN\n",
    "train_kwargs = dict(\n",
    "    glob_w=glob_w, lr_vals=lr_vals, adapt_state=True, loc_w=loc_w, nesterov=True,\n",
    "    num_epochs=num_epochs, grid_extend=grid_extend, grid_adapt=grid_adapt,\n",
    "    colloc_adapt={'epochs' : []},\n",
    ")\n",
    "\n",
    "# NOTE: model and variables are rebound to the returned values, so running this\n",
    "# cell again starts from the already-trained state rather than from the init cell.\n",
    "model, variables, train_losses3 = train_PIKAN(model, variables, pde_loss, collocs, bc_collocs, bc_data,\n",
    "                                              **train_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b02855c-cf18-4b3e-afe2-85050d3c3bb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def helm_exact(x,y):\n",
    "    \"\"\"Analytical solution sin(a1*pi*x)*sin(a2*pi*y) with a1 = 1 and a2 = 4.\"\"\"\n",
    "    a1, a2 = 1.0, 4.0\n",
    "    return np.sin(a1*np.pi*x)*np.sin(a2*np.pi*y)\n",
    "\n",
    "# Regular evaluation grid on [-1, 1] x [-1, 1]; 'ij' indexing gives arrays of shape (N_x, N_y)\n",
    "N_x, N_y = 100, 256\n",
    "x = np.linspace(-1.0, 1.0, N_x)\n",
    "y = np.linspace(-1.0, 1.0, N_y)\n",
    "X, Y = np.meshgrid(x, y, indexing='ij')\n",
    "\n",
    "# One (x, y) pair per row, in the same row-major order used by the reshape below\n",
    "coords = np.column_stack([X.ravel(), Y.ravel()])\n",
    "\n",
    "ref = helm_exact(X, Y)\n",
    "\n",
    "# Model prediction on the grid; the second element returned by apply is not needed\n",
    "output, _ = model.apply(variables, jnp.array(coords))\n",
    "fullya = np.array(output).reshape(N_x, N_y)\n",
    "\n",
    "# Relative L^2 error against the analytical solution\n",
    "l2err = jnp.linalg.norm(fullya-ref)/jnp.linalg.norm(ref)\n",
    "print(f\"L^2 Error = {l2err*100:.4f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bb37ae3-7e3d-445b-81b7-8a43ab9b6802",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loss history on a logarithmic y-axis\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "ax.plot(np.array(train_losses3), label='Train Loss', marker='o', color='blue', markersize=1)\n",
    "\n",
    "ax.set_xlabel('Epochs')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss Over Epochs')\n",
    "ax.set_yscale('log')\n",
    "\n",
    "ax.legend()\n",
    "ax.grid(True, which='both', linestyle='--', linewidth=0.5)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0068acc7-298c-481f-8230-116158659fcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute error between the prediction and the analytical solution on the evaluation grid\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "err_mesh = ax.pcolormesh(X, Y, np.abs(fullya-ref), shading='auto', cmap='Spectral_r')\n",
    "fig.colorbar(err_mesh, ax=ax)\n",
    "\n",
    "ax.set_title('Absolute Error for Helmholtz Equation')\n",
    "ax.set_xlabel('x')\n",
    "ax.set_ylabel('y')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1468d1-147d-42b8-a0e7-e666650998e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): epochs and grids are computed here but are not written to the archive below\n",
    "epochs = np.arange(num_epochs)\n",
    "grids = np.array(list(grid_extend.keys()))\n",
    "\n",
    "# Store the grid axes, the prediction and the analytical solution for plotting elsewhere\n",
    "saved_arrays = {'x' : x, 'y' : y, 'res' : fullya, 'ref' : ref}\n",
    "np.savez('../Plots/data/relu3.npz', **saved_arrays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcdb78e0-1510-4443-8989-9d42d337fa5b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
